## make sure same labels for training and testing categories

## Convert the training data to an all-numeric form by joining in the
## numeric codes from the reference lookup tables, then dropping the
## original labelled columns.
all_numeric <- dat %>%
  left_join(reference_schtype) %>%
  left_join(reference_school) %>%
  left_join(reference_district) %>%
  select(-c(district, school_name, fac_schtype))

## Same numeric recoding for the 2018 hold-out data; rows with missing
## values are dropped so the test matrix is complete.
all_numeric_test <- dat_2018 %>%
  left_join(reference_schtype) %>%
  left_join(reference_school) %>%
  left_join(reference_district) %>%
  select(-c(district, school_name, fac_schtype)) %>%
  drop_na()
# Use ALL data because we have another year as our final test data
# Use ALL data because we have another year as our final test data

## Predictors for training: every column except the outcome.
train_data <- all_numeric %>%
  select(-exclusionary)

## Recode the year column as 1:3 because there are three training years.
## NOTE(review): column 15 is addressed by position and is assumed to be
## the year column — a named reference (e.g. train_data$year) would be
## safer; confirm against the column order of `all_numeric`.
train_data[, 15] <- c(
  rep(1, nrow(dat_2015)),
  rep(2, nrow(dat_2016)),
  rep(3, nrow(dat_2017))
)

## Outcome vector for training. pull() replaces select() + unlist(),
## and avoids the `F` shorthand for FALSE.
train_targets <- all_numeric %>%
  pull(exclusionary)

## Predictors for the 2018 test set.
test_data <- all_numeric_test %>%
  select(-exclusionary)

## Year coded as 4, continuing the 1:3 sequence used for training.
## (The original wrapped rep() in a redundant c().)
## NOTE(review): column 15 assumed to be the year column, as in training.
test_data[, 15] <- rep(4, nrow(test_data))

## Outcome vector for the test set; pull() replaces select() + unlist().
test_targets <- all_numeric_test %>%
  pull(exclusionary)
## Center and scale the training predictors. scale() already returns a
## matrix; as.matrix() just guarantees the class.
train_scale <- as.matrix(scale(train_data))

## Center and scale the test set based on the TRAIN data.
## NOTE: originally train and test were scaled SEPARATELY; this instead
## re-scales test using train's mean and sd.

## `{caret}`-based solution: preProcess() defaults to center + scale.
normParam <- preProcess(train_data)
test_scale <- as.matrix(predict(normParam, test_data))

## Original way the test set was scaled:
# test_scale <-
#   scale(test_data) %>%
#   as.matrix()
## If you scale that way, the rep(4) year column has to be added AFTER
## scaling: 2018 has a single year value, scale() returns NaN for a
## zero-variance column, and the NaNs then propagate through the whole
## model.
## Feed-forward regression network: three hidden layers (64 units, ReLU)
## with L1/L2 weight regularization and 25% dropout after each, ending
## in a single linear unit for the count outcome.
## input_shape now tracks the actual number of scaled predictors rather
## than the previously hard-coded 18.
model <- keras_model_sequential() %>%
    layer_dense(units = 64, activation = "relu",
                input_shape = ncol(train_scale),
                kernel_regularizer = regularizer_l1_l2(l1 = 0.001, l2 = 0.001)) %>%
    layer_dropout(rate = 0.25) %>%
    layer_dense(units = 64, activation = "relu",
                kernel_regularizer = regularizer_l1_l2(l1 = 0.001, l2 = 0.001)) %>%
    layer_dropout(rate = 0.25) %>%
    layer_dense(units = 64, activation = "relu",
                kernel_regularizer = regularizer_l1_l2(l1 = 0.001, l2 = 0.001)) %>%
    layer_dropout(rate = 0.25) %>%
    layer_dense(units = 1)

## MSE loss for the regression target, with MAE tracked as a more
## interpretable metric. compile() modifies the model in place.
compile(
  model,
  optimizer = "rmsprop",
  loss = "mse",
  metrics = c("mae")
)

## Mean absolute error (MAE) = the mean of the absolute differences
## between the predictions and the targets.

## Fit on every training observation; batch_size = 1 means one gradient
## update per sample.
history <- fit(
  model,
  train_scale,
  train_targets,
  epochs = 25,
  batch_size = 1,
  verbose = 1
)
history
## Trained on 7,614 samples (batch_size=1, epochs=25)
## Final epoch (plot to see history):
## loss: 11,737
##  mae: 53.23
## Evaluate loss (MSE) and MAE on the scaled 2018 hold-out data.
evaluation_metrics <- evaluate(model, test_scale, test_targets)
evaluation_metrics
## $loss
## [1] 5648.562
##
## $mae
## [1] 41.01249
# model %>% 
#   save_model_hdf5("tuned_model_2.h5")
## Prediction wrapper in the form {vip}/{pdp} expect: takes a fitted
## model and a data frame of predictors, returns predictions as a plain
## numeric vector.
pred_wrapper <- function(object, newdata) {
  preds <- predict(object, x = as.matrix(newdata))
  as.vector(preds)
}

## Permutation-based variable importance across ALL predictors
## (the vip() default would plot only the top 10).
p1 <- vip(
  object = model,                    # fitted keras model
  method = "permute",                # permutation-based VI scores
  metric = "mse",                    # evaluation metric
  num_features = ncol(train_scale),  # show every feature
  pred_wrapper = pred_wrapper,       # user-defined prediction function
  train = train_scale,               # training data
  target = train_targets,            # response values used for training
  progress = "text"                  # text-based progress bar
)
## (Text progress bar from the permutation run, 0% -> 100%, omitted.)
print(p1)  # render the variable-importance plot

## Partial dependence: school level (elementary, middle, high, or K-12).
p2 <- autoplot(
  partial(model, pred.var = "level_schtype", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)

## Partial dependence: total number of students in the school.
p3 <- autoplot(
  partial(model, pred.var = "total", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)
grid.arrange(p2, p3, ncol = 2)  # display plots side by side

## Partial dependence: school-level percent FRL (free/reduced lunch).
p4 <- autoplot(
  partial(model, pred.var = "perc_frl", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)

## Partial dependence: percent white students in the school.
p5 <- autoplot(
  partial(model, pred.var = "perc_white", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)

grid.arrange(p4, p5, ncol = 2)  # display plots side by side

## Partial dependence: school-level percent Black students.
p6 <- autoplot(
  partial(model, pred.var = "perc_black", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)

## Partial dependence: district-level percent FRL.
p7 <- autoplot(
  partial(model, pred.var = "d_frl", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)
grid.arrange(p6, p7, ncol = 2)  # display plots side by side

## Partial dependence: the school-name factor coded over all schools.
p8 <- autoplot(
  partial(model, pred.var = "level_school", pred.fun = pred_wrapper,
          train = train_scale),
  alpha = 0.1
)
p8